--- external/gsn/models/layers.py	2023-03-28 23:42:27.913662937 +0000
+++ external_reference/gsn/models/layers.py	2023-03-28 23:41:48.297127705 +0000
@@ -4,9 +4,31 @@
 from torch import nn
 import torch.nn.functional as F
 
-from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
+from torch_utils.ops import upfirdn2d
+from torch_utils.ops import bias_act
+from torch_utils import persistence
+
+# CHANGED: reimplemented on top of torch_utils.ops.bias_act (replaces the local .op extension).
+def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5):
+    """Forward to bias_act.bias_act with act='lrelu', alpha=negative_slope and gain=scale."""
+    return bias_act.bias_act(input, bias, act='lrelu', alpha=negative_slope, gain=scale)
+
+@persistence.persistent_class
+class FusedLeakyReLU(nn.Module):
+    """Module wrapper around `fused_leaky_relu` holding an optional bias parameter of length `channel`."""
+    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
+        super().__init__()
+        # Zero-initialised learnable bias when `bias` is truthy; None disables it.
+        self.bias = nn.Parameter(torch.zeros(channel)) if bias else None
 
+        self.scale = scale
+        self.negative_slope = negative_slope
 
+    def forward(self, input):
+        """Apply `fused_leaky_relu` to `input` with the stored bias, slope and scale."""
+        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+@persistence.persistent_class
 class PixelNorm(nn.Module):
     """Pixel normalization layer.
 
@@ -20,6 +42,7 @@
         return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
 
 
+@persistence.persistent_class
 class ConstantInput(nn.Module):
     """Constant input layer.
 
@@ -58,6 +81,7 @@
     return k
 
 
+@persistence.persistent_class
 class Blur(nn.Module):
     """Blur layer.
 
@@ -81,7 +105,7 @@
     def __init__(self, kernel, pad, upsample_factor=1):
         super().__init__()
 
-        kernel = make_kernel(kernel)
+        kernel = upfirdn2d.setup_filter(kernel)
 
         if upsample_factor > 1:
             kernel = kernel * (upsample_factor ** 2)
@@ -90,10 +114,11 @@
         self.pad = pad
 
     def forward(self, input):
-        out = upfirdn2d(input, self.kernel, pad=self.pad)
+        out = upfirdn2d.upfirdn2d(input, self.kernel, padding=self.pad)  # NOTE(review): confirm a 2-element pad is still read as (before, after) per axis
         return out
 
 
+@persistence.persistent_class
 class Upsample(nn.Module):
     """Upsampling layer.
 
@@ -112,19 +137,21 @@
         super().__init__()
 
         self.factor = factor
-        kernel = make_kernel(kernel) * (factor ** 2)
+        kernel = upfirdn2d.setup_filter(kernel)  # the * (factor ** 2) scaling is handled by upfirdn2d.upsample2d (gain)
+        # kernel = make_kernel(kernel) * (factor ** 2)
         self.register_buffer("kernel", kernel)
 
-        p = kernel.shape[0] - factor
-        pad0 = (p + 1) // 2 + factor - 1
-        pad1 = p // 2
-        self.pad = (pad0, pad1)
+        # p = kernel.shape[0] - factor
+        # pad0 = (p + 1) // 2 + factor - 1
+        # pad1 = p // 2
+        self.pad = (0, 0) # upfirdn2d.upsample2d handles additional padding
 
     def forward(self, input):
-        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
+        out = upfirdn2d.upsample2d(input, self.kernel, up=self.factor, padding=self.pad)  # self.pad is (0, 0), set in __init__
         return out
 
 
+@persistence.persistent_class
 class Downsample(nn.Module):
     """Downsampling layer.
 
@@ -143,19 +170,22 @@
         super().__init__()
 
         self.factor = factor
-        kernel = make_kernel(kernel)
+        kernel = upfirdn2d.setup_filter(kernel) # make_kernel(kernel)
         self.register_buffer("kernel", kernel)
 
+        # downsample needs padding
         p = kernel.shape[0] - factor
         pad0 = (p + 1) // 2
         pad1 = p // 2
         self.pad = (pad0, pad1)
+        # self.pad = (0, 0)
 
     def forward(self, input):
-        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
+        out = upfirdn2d.upfirdn2d(input, self.kernel, up=1, down=self.factor, padding=self.pad * 2)  # (x0, x1, y0, y1); a 2-tuple would be read as (padx, pady)
         return out
 
 
+@persistence.persistent_class
 class EqualLinear(nn.Module):
     """Linear layer with equalized learning rate.
 
@@ -207,6 +237,7 @@
         return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
 
 
+@persistence.persistent_class
 class EqualConv2d(nn.Module):
     """2D convolution layer with equalized learning rate.
 
@@ -256,6 +287,7 @@
         )
 
 
+@persistence.persistent_class
 class EqualConvTranspose2d(nn.Module):
     """2D transpose convolution layer with equalized learning rate.
 
@@ -315,6 +347,7 @@
         )
 
 
+@persistence.persistent_class
 class ConvLayer2d(nn.Sequential):
     def __init__(
         self,
@@ -367,6 +400,7 @@
         super().__init__(*layers)
 
 
+@persistence.persistent_class
 class ConvResBlock2d(nn.Module):
     """2D convolutional residual block with equalized learning rate.
 
@@ -417,6 +451,7 @@
         return out
 
 
+@persistence.persistent_class
 class ModulationLinear(nn.Module):
     """Linear modulation layer.
 
@@ -497,6 +532,7 @@
         return out
 
 
+@persistence.persistent_class
 class ModulatedConv2d(nn.Module):
     """2D convolutional modulation layer.
 
@@ -631,6 +667,7 @@
         return out
 
 
+@persistence.persistent_class
 class ToRGB(nn.Module):
     """Output aggregation layer.
 
@@ -675,6 +712,7 @@
         return out
 
 
+@persistence.persistent_class
 class ConvRenderBlock2d(nn.Module):
     """2D convolutional neural rendering block.
 
@@ -742,6 +780,7 @@
         return x, rgb
 
 
+@persistence.persistent_class
 class PositionalEncoding(nn.Module):
     """Positional encoding layer.
 
